{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:28.082069Z",
     "start_time": "2024-04-02T13:15:26.947498Z"
    },
    "id": "_YcCBE70c5rz"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n"
     ]
    }
   ],
   "source": [
    "#All library imports\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import numbers\n",
    "import ntpath\n",
    "import math\n",
    "import piq\n",
    "import lpips\n",
    "import functools\n",
    "from PIL import Image\n",
    "from torchvision import transforms, utils\n",
    "import cv2\n",
    "from torch.utils import data\n",
    "from skimage import io, transform\n",
    "import pandas as pd\n",
    "import json\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms.functional as TF\n",
    "\n",
    "import natsort, glob, pickle, torch\n",
    "from collections import OrderedDict\n",
    "from utils.util import opt_get\n",
    "import options.options as option\n",
    "from models import create_model\n",
    "from imresize import imresize\n",
    "import Measure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:28.323496Z",
     "start_time": "2024-04-02T13:15:28.316982Z"
    },
    "id": "X-gmQCbDYgll"
   },
   "outputs": [],
   "source": [
    "#Important function definitions\n",
    "\n",
    "def find_files(wildcard):\n",
    "    \"\"\"Return every path matching the glob pattern `wildcard`, in natural sort order.\n",
    "\n",
    "    The pattern is expanded with `glob.glob(..., recursive=True)` and the matches\n",
    "    are then ordered by `natsort.natsorted`.\n",
    "    \"\"\"\n",
    "    matches = glob.glob(wildcard, recursive=True)\n",
    "    return natsort.natsorted(matches)\n",
    "\n",
    "def imshow(array):\n",
    "    \"\"\"Show `array` inline: it is wrapped in a PIL image and handed to `display`.\"\"\"\n",
    "    picture = Image.fromarray(array)\n",
    "    display(picture)\n",
    "\n",
    "#from test import load_model, fiFindByWildcard, imread\n",
    "\n",
    "def pickleRead(path):\n",
    "    \"\"\"Open `path` in binary mode and return the object unpickled from it.\n",
    "\n",
    "    NOTE(review): `pickle.load` can execute arbitrary code, so this should only\n",
    "    be pointed at trusted files.\n",
    "    \"\"\"\n",
    "    with open(path, mode='rb') as stream:\n",
    "        payload = pickle.load(stream)\n",
    "    return payload\n",
    "\n",
    "def load_model(conf_path):\n",
    "    \"\"\"Build a model from the config file at `conf_path` and load its generator weights.\n",
    "\n",
    "    Returns:\n",
    "        A `(model, opt)` tuple, where `opt` is the parsed option object.\n",
    "    \"\"\"\n",
    "    # Parse in non-training mode, reset 'gpu_ids' to None, then wrap with dict_to_nonedict.\n",
    "    parsed = option.parse(conf_path, is_train=False)\n",
    "    parsed['gpu_ids'] = None\n",
    "    opt = option.dict_to_nonedict(parsed)\n",
    "\n",
    "    model = create_model(opt)\n",
    "    # The checkpoint location is looked up under the 'model_path' key of the options.\n",
    "    weights_path = opt_get(opt, ['model_path'], None)\n",
    "    model.load_network(load_path=weights_path, network=model.netG)\n",
    "    return model, opt\n",
    "\n",
    "def fiFindByWildcard(wildcard):\n",
    "    \"\"\"Glob `wildcard` recursively and return the matching paths in natural sort order.\"\"\"\n",
    "    found = glob.glob(wildcard, recursive=True)\n",
    "    ordered = natsort.natsorted(found)\n",
    "    return ordered\n",
    "\n",
    "def imread(path):\n",
    "    \"\"\"Read the image at `path` with OpenCV and return it with its channels reordered.\n",
    "\n",
    "    The last axis is indexed with [2, 1, 0], i.e. the first and third channels are\n",
    "    swapped (OpenCV's BGR order becomes RGB).\n",
    "\n",
    "    Raises:\n",
    "        OSError: if OpenCV cannot read the file. `cv2.imread` returns None instead\n",
    "            of raising, which otherwise surfaces as an opaque TypeError on indexing.\n",
    "    \"\"\"\n",
    "    bgr = cv2.imread(path)\n",
    "    if bgr is None:\n",
    "        raise OSError(f'Could not read image: {path!r}')\n",
    "    return bgr[:, :, [2, 1, 0]]\n",
    "\n",
    "def t(array):\n",
    "    \"\"\"Convert an image array into a batched float32 tensor on 'cuda', divided by 255.\n",
    "\n",
    "    The axes are reordered with transpose([2, 0, 1]) and a leading axis of size 1\n",
    "    is inserted before the data is handed to torch.\n",
    "    \"\"\"\n",
    "    channels_first = array.transpose([2, 0, 1])\n",
    "    batched = np.expand_dims(channels_first, axis=0).astype(np.float32)\n",
    "    return torch.Tensor(batched).to('cuda') / 255\n",
    "\n",
    "def rgb(t):\n",
    "    \"\"\"Convert a tensor with values nominally in [0, 1] to a uint8 image array.\n",
    "\n",
    "    A 4-D input is reduced to its first element; the remaining three axes are\n",
    "    reordered with transpose([1, 2, 0]). Values are clipped to [0, 1], scaled by\n",
    "    255 and cast to uint8 (truncating, not rounding).\n",
    "\n",
    "    NaN entries are replaced by 0 before the cast: `np.clip` lets NaN through and\n",
    "    casting NaN to uint8 is undefined (NumPy reports\n",
    "    'RuntimeWarning: invalid value encountered in cast').\n",
    "    \"\"\"\n",
    "    image = (t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0])\n",
    "    # Infinities are still handled by the clip below; only NaN needs explicit treatment.\n",
    "    image = np.nan_to_num(image, nan=0.0)\n",
    "    return (np.clip(image, 0, 1) * 255).astype(np.uint8)\n",
    "\n",
    "def metric_stats(metric_dict_list):\n",
    "    \"\"\"Summarise metric values as [min, max, mean, std], each rounded to 3 decimals.\"\"\"\n",
    "    values = np.array(metric_dict_list)\n",
    "    reducers = (np.min, np.max, np.mean, np.std)\n",
    "    return [round(stat_fn(values), 3) for stat_fn in reducers]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:29.473239Z",
     "start_time": "2024-04-02T13:15:29.459432Z"
    },
    "id": "9Y9yVsloYgn9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../pretrained_models/ESRGAN_DF2K_8X.pth',\n",
       " '../pretrained_models/RRDB_CelebA_8X.pth',\n",
       " '../pretrained_models/RRDB_DF2K_4X.pth',\n",
       " '../pretrained_models/RRDB_DF2K_8X.pth',\n",
       " '../pretrained_models/SRFlow_CelebA_8X.pth',\n",
       " '../pretrained_models/SRFlow_DF2K_4X.pth',\n",
       " '../pretrained_models/SRFlow_DF2K_8X.pth']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the *.pth files found directly under ../pretrained_models.\n",
    "find_files(\"../pretrained_models/*.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:30.264163Z",
     "start_time": "2024-04-02T13:15:30.259625Z"
    },
    "id": "NFhMW-bBYgqJ"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['confs/RRDB_CelebA_8X.yml',\n",
       " 'confs/RRDB_DF2K_4X.yml',\n",
       " 'confs/RRDB_DF2K_8X.yml',\n",
       " 'confs/SRFlow_CelebA_8X.yml',\n",
       " 'confs/SRFlow_DF2K_4X.yml',\n",
       " 'confs/SRFlow_DF2K_8X.yml']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the *.yml files found directly under confs/.\n",
    "find_files(\"confs/*.yml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:33.067007Z",
     "start_time": "2024-04-02T13:15:31.258622Z"
    },
    "id": "LPmc3-a5Ygsk"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('manual_seed', 10), ('lr_G', 0.0005), ('weight_decay_G', 0), ('beta1', 0.9), ('beta2', 0.99), ('lr_scheme', 'MultiStepLR'), ('warmup_iter', -1), ('lr_steps_rel', [0.5, 0.75, 0.9, 0.95]), ('lr_gamma', 0.5), ('niter', 200000), ('val_freq', 40000), ('lr_steps', [100000, 150000, 180000, 190000])])\n",
      "32\n",
      "32\n",
      "0.5\n",
      "4\n",
      "16\n",
      "[1, 3, 5, 7]\n",
      "[1, 3, 5, 7]\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "auto\n",
      "../pretrained_models/SRFlow_CelebA_8X.pth\n",
      "OrderedDict([('manual_seed', 10), ('lr_G', 0.00025), ('weight_decay_G', 0), ('beta1', 0.9), ('beta2', 0.99), ('lr_scheme', 'MultiStepLR'), ('warmup_iter', -1), ('lr_steps_rel', [0.5, 0.75, 0.9, 0.95]), ('lr_gamma', 0.5), ('niter', 200000), ('val_freq', 40000), ('lr_steps', [100000, 150000, 180000, 190000])])\n",
      "32\n",
      "32\n",
      "0.5\n",
      "3\n",
      "16\n",
      "[1, 8, 15, 22]\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "auto\n",
      "../pretrained_models/SRFlow_DF2K_4X.pth\n"
     ]
    }
   ],
   "source": [
    "# Config files of the two pretrained SRFlow models used by the evaluation below.\n",
    "conf_path = 'confs/SRFlow_CelebA_8X.yml'\n",
    "conf_path2 = 'confs/SRFlow_DF2K_4X.yml'\n",
    "\n",
    "# modelX8 / modelX4 are read as globals by IDA_RD.\n",
    "modelX8, opt1 = load_model(conf_path)\n",
    "modelX4, opt2 = load_model(conf_path2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:22:47.506804Z",
     "start_time": "2024-04-02T13:22:47.148766Z"
    },
    "id": "Pderd41woBBQ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]\n",
      "Loading model from: /home/yuanbangliang/anaconda3/envs/p_r_hubs/lib/python3.8/site-packages/lpips/weights/v0.1/alex.pth\n"
     ]
    }
   ],
   "source": [
    "#initialize LPIPS class\n",
    "def print_val(variable_name, vals_):\n",
    "    \"\"\"Print the minimum, maximum and mean of `vals_`, prefixed by `variable_name`.\"\"\"\n",
    "    samples = np.array(vals_)\n",
    "    lowest, highest, average = samples.min(), samples.max(), samples.mean()\n",
    "    pieces = [str(variable_name), \"-- min:\", str(lowest), \" max:\", str(highest), \" avg:\", str(average)]\n",
    "    print(\" \".join(pieces))\n",
    "    \n",
    "def IDA_RD(lr_dir, hr_dir, num_imgs):\n",
    "    \"\"\"Sample `num_imgs` super-resolutions of one LR image and print LPIPS statistics.\n",
    "\n",
    "    Args:\n",
    "        lr_dir: path of the low-resolution image file.\n",
    "        hr_dir: path of the high-resolution reference image file.\n",
    "        num_imgs: number of super-resolved samples to draw (must be >= 1).\n",
    "\n",
    "    The upscaling factor is the ratio of the image heights (`shape[0]`):\n",
    "    4 uses `modelX4`, 8 uses `modelX8`, and 32 chains `modelX8` followed by `modelX4`.\n",
    "    SSIM, PSNR and MS-SSIM are computed for every sample as well, but only the\n",
    "    LPIPS values are printed (via `print_val`). Nothing is returned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `num_imgs` < 1 or the height ratio is not 4, 8 or 32.\n",
    "    \"\"\"\n",
    "    if num_imgs < 1:\n",
    "        raise ValueError(f'num_imgs must be >= 1, got {num_imgs}')\n",
    "\n",
    "    lq = imread(lr_dir)\n",
    "    gt = imread(hr_dir)\n",
    "\n",
    "    # The scale depends only on the two input images, so compute it once.\n",
    "    leve_size = gt.shape[0] / lq.shape[0]\n",
    "    if leve_size not in (4, 8, 32):\n",
    "        # An unsupported ratio used to leave `sr` unassigned and fail with an\n",
    "        # UnboundLocalError inside the sampling loop.\n",
    "        raise ValueError(f'Unsupported upscaling factor {leve_size}; expected 4, 8 or 32')\n",
    "\n",
    "    # Reference image as a 1xCxHxW tensor divided by 255; moved to the GPU once\n",
    "    # instead of on every iteration.\n",
    "    gt = (torch.tensor(gt).permute(2, 0, 1)[None, ...] / 255).cuda()\n",
    "\n",
    "    lpips_vals, ssim_vals = [], []\n",
    "    ms_ssim_vals, psnr_vals = [], []\n",
    "\n",
    "    for _ in range(num_imgs):\n",
    "        # Sample a super-resolution for the low-resolution image.\n",
    "        if leve_size == 4:\n",
    "            sr = rgb(modelX4.get_sr(lq=t(lq), heat=1.1))\n",
    "        elif leve_size == 8:\n",
    "            sr = rgb(modelX8.get_sr(lq=t(lq), heat=1))\n",
    "        else:\n",
    "            # leve_size == 32: 8x upscaling followed by 4x upscaling.\n",
    "            sr1 = rgb(modelX8.get_sr(lq=t(lq), heat=1))\n",
    "            sr = rgb(modelX4.get_sr(lq=t(sr1), heat=0.1))\n",
    "\n",
    "        sr = (torch.tensor(sr).permute(2, 0, 1)[None, ...] / 255.).cuda()\n",
    "\n",
    "        # Only scalar values are kept, so no autograd graph is needed.\n",
    "        with torch.no_grad():\n",
    "            ssim_index = piq.ssim(gt, sr, data_range=1.)\n",
    "            psnr_index = piq.psnr(gt, sr, data_range=1., reduction='none')\n",
    "            ms_ssim_index = piq.multi_scale_ssim(gt, sr, data_range=1.)\n",
    "            # LPIPS is fed the images rescaled from [0, 1] to [-1, 1].\n",
    "            lpips_index = loss_fn_alex(2 * gt - 1, 2 * sr - 1)\n",
    "\n",
    "        ssim_vals.append(ssim_index.item())\n",
    "        psnr_vals.append(psnr_index.item())\n",
    "        lpips_vals.append(lpips_index.item())\n",
    "        ms_ssim_vals.append(ms_ssim_index.item())\n",
    "\n",
    "    # The other metric lists are collected but intentionally not reported here;\n",
    "    # pass them to print_val to show them as well.\n",
    "    print_val('LPIPS', lpips_vals)\n",
    "    \n",
    "# LPIPS metric with the AlexNet backbone (best forward scores); IDA_RD reads this global.\n",
    "loss_fn_alex = lpips.LPIPS(net='alex')\n",
    "if torch.cuda.is_available():\n",
    "    loss_fn_alex = loss_fn_alex.cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:15:37.721529Z",
     "start_time": "2024-04-02T13:15:37.717164Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<models.SRFlow_model.SRFlowModel object at 0x7fdd4fb13d30>\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: print the loaded 8x model object.\n",
    "print(modelX8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:22:54.910859Z",
     "start_time": "2024-04-02T13:22:51.697225Z"
    },
    "id": "aPShugoBYgyX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_118747/53249895.py:34: RuntimeWarning: invalid value encountered in cast\n",
      "  def rgb(t): return (np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(np.uint8)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "LPIPS -- min: 0.059829387813806534  max: 0.09941761940717697  avg: 0.07030895948410035\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm_notebook\n",
    "\n",
    "# Reference (HR) image and the number of SR samples drawn per LR image;\n",
    "# both are reused by the evaluation cells that follow.\n",
    "hr_dir = 'testing_images/ori_image/0898.png'\n",
    "num_imgs = 5\n",
    "\n",
    "# LR input taken from the BICUBIC_256 folder.\n",
    "lr_dir = 'testing_images/lr_image/BICUBIC_256/0898.png'\n",
    "IDA_RD(lr_dir=lr_dir, hr_dir=hr_dir, num_imgs=num_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:23:05.917689Z",
     "start_time": "2024-04-02T13:23:02.889282Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "LPIPS -- min: 0.042730726301670074  max: 0.0452091284096241  avg: 0.04426551535725594\n"
     ]
    }
   ],
   "source": [
    "# LR input taken from the BILINEAR_256 folder; hr_dir and num_imgs come from an earlier cell.\n",
    "lr_dir = 'testing_images/lr_image/BILINEAR_256/0898.png'\n",
    "IDA_RD(lr_dir=lr_dir, hr_dir=hr_dir, num_imgs=num_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-02T13:23:11.565974Z",
     "start_time": "2024-04-02T13:23:08.498249Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_118747/53249895.py:34: RuntimeWarning: invalid value encountered in cast\n",
      "  def rgb(t): return (np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(np.uint8)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "[1, 8, 15, 22]\n",
      "True\n",
      "LPIPS -- min: 0.716343343257904  max: 0.7446193099021912  avg: 0.7297610640525818\n"
     ]
    }
   ],
   "source": [
    "# LR input taken from the NEAREST_256 folder; hr_dir and num_imgs come from an earlier cell.\n",
    "lr_dir = 'testing_images/lr_image/NEAREST_256/0898.png'\n",
    "IDA_RD(lr_dir=lr_dir, hr_dir=hr_dir, num_imgs=num_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "SRFlow.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
